#!/usr/bin/env python3

# Copyright Jamie Iles, 2017
#
# This file is part of s80x86.
#
# s80x86 is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# s80x86 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with s80x86.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import math
import os
import subprocess
import pystache

from functools import partial, lru_cache
from textx.metamodel import metamodel_from_file

from microasm.types import (
    GPR,
    SR,
    ADriver,
    BDriver,
    ALUOp,
    RDSelSource,
    RegWrSource,
    UpdateFlags,
    JumpType,
    MARWrSel,
    TEMPWrSel,
    WidthType
)

# Paths of the grammar and of the output templates, resolved relative to the
# directory containing this script.
HERE = os.path.dirname(__file__)
GRAMMAR = os.path.join(HERE, 'microcode_grammar.g')
TEMPLATE = os.path.join(HERE, '..', '..', 'rtl', 'microcode', 'Microcode.sv.templ')
MIF_TEMPLATE = os.path.join(HERE, '..', '..', 'rtl', 'microcode', 'microcode.mif.templ')
VERILOG_ENUM_TEMPLATE = os.path.join(HERE, '..', '..', 'rtl', 'microcode', 'MicrocodeTypes.sv.templ')
CPP_ENUM_TEMPLATE = os.path.join(HERE, '..', '..', 'rtl', 'microcode', 'MicrocodeTypes.h.templ')
# The skipws value is unchanged (backslash, 's', tab, newline, backslash); the
# leading backslash is now escaped explicitly because '\s' is not a valid
# string escape and triggers a SyntaxWarning on recent Python versions.
mm = metamodel_from_file(GRAMMAR, skipws='\\s\t\n\\', debug=False)

# Templates use <% %> delimiters instead of pystache's default {{ }}.
pystache.defaults.DELIMITERS = (u'<%', u'%>')

def num_bits(max_val):
    """Return the number of bits needed to encode ``max_val`` distinct values.

    This is ``ceil(log2(max_val))``, computed with exact integer arithmetic
    so that large powers of two are not over-counted through floating-point
    rounding.  As a special case ``0`` yields ``1``.

    Args:
        max_val: non-negative count of distinct values.

    Returns:
        The bit count as an ``int`` (``0`` when ``max_val`` is ``1``).

    Raises:
        ValueError: if ``max_val`` is negative.
    """
    if max_val == 0:
        return 1
    if max_val < 0:
        raise ValueError('num_bits() requires a non-negative value, got {0!r}'.format(max_val))
    # ceil(log2(n)) == bit_length(n - 1) for every integer n >= 1; math.ceil
    # keeps non-integral inputs working the way the logarithm form did.
    return (math.ceil(max_val) - 1).bit_length()

def token_type(t):
    """Return the name of the class that ``t`` is an instance of."""
    cls = t.__class__
    return cls.__name__

def field_type(size):
    """Return the ``[msb:0] `` range prefix for a field ``size`` bits wide.

    A one-bit field needs no range, so the empty string is returned for it.
    """
    is_single_bit = size == 1
    return '' if is_single_bit else '[{0}:0] '.format(size - 1)

class MicroAssembler(object):
    def __init__(self, infiles, bin_output, mif_output, verilog_output,
                 verilog_types_output, cpp_types_output, include):
        """Remember the input files, output paths and include directories.

        Args:
            infiles: microcode source files handed to ``_build_microcode``.
            bin_output: path written by ``_write_bin``.
            mif_output: path written by ``_write_mif``.
            verilog_output: path written by ``_write_microcode_verilog``.
            verilog_types_output: Verilog path written by ``_write_types``.
            cpp_types_output: C++ path written by ``_write_types``.
            include: directories passed to the preprocessor as ``-I`` options.
        """
        # Inputs.
        self._infiles, self._include = infiles, include

        # Output locations.
        self._bin_output, self._mif_output = bin_output, mif_output
        self._verilog_output = verilog_output
        self._verilog_types_output = verilog_types_output
        self._cpp_types_output = cpp_types_output

        # Distinct immediate / update-flags values, in order of first use.
        self._immediate_pool, self._update_flags_pool = [], []

    def _build_microcode(self, files):
        """Parse ``files`` and return their microinstructions sorted by address.

        Each file is run through ``cpp``, parsed with the textX metamodel and
        converted into ``MicroInstruction`` objects.  Once every file has been
        read, addresses are assigned, label references are resolved and the
        immediate / update-flags pools are finalized.

        Args:
            files: iterable of microcode source paths.

        Returns:
            List of ``MicroInstruction`` ordered by ascending address.

        Raises:
            ValueError: no instruction was found, a directive or token type
                is not recognised, or two instructions share an address.
            KeyError: a label is defined twice, or is referenced but never
                defined.
        """
        label_map = {}
        instructions = []

        def assign_addresses():
            # Report empty input explicitly rather than letting max() fail
            # with "arg is an empty sequence".
            if not instructions:
                raise ValueError('No microinstructions found in input files')
            used_addresses = set()
            # Automatic addresses are handed out above the highest explicit one.
            last_addr = max(i.address if i.address else 0 for i in instructions)
            for instr in instructions:
                if instr.address is None:
                    last_addr += 1
                    instr.address = last_addr
                if instr.address in used_addresses:
                    raise ValueError('Duplicate instruction at address {0:x}'.format(instr.address))
                used_addresses.add(instr.address)

        def resolve_labels():
            for instr in instructions:
                if not instr.next_label:
                    # No jump: fall through to the following address.
                    instr.next = instr.address + 1
                    continue
                try:
                    target = label_map[instr.next_label]
                except KeyError:
                    # Same exception type as before, but say which label is
                    # missing and where it was referenced.
                    raise KeyError('Undefined label "{0}" referenced at {1}'.format(
                        instr.next_label, instr.origin)) from None
                instr.next = target.address

        def preprocess(filename):
            includes = []
            for directory in self._include:
                includes.extend(['-I', directory])
            command = ['cpp', '-nostdinc', filename, '-o', '-'] + includes
            return subprocess.check_output(command).decode('utf-8')

        for f in files:
            model = mm.model_from_str(preprocess(f))
            current_label = None
            current_address = None

            # Difference between line numbers in the preprocessed text and in
            # the original file; updated from PreprocessorDirective tokens.
            line_offset = 0
            for d in model.lines:
                kind = token_type(d)
                line, _ = mm.parser.pos_to_linecol(d._tx_position)
                line -= line_offset + 1
                if kind == 'LabelAnchor':
                    current_label = d.label
                    if current_label in label_map:
                        raise KeyError('Duplicate label "{0}"'.format(current_label))
                elif kind == 'Directive':
                    if d.directive == 'at':
                        current_address = int(d.arguments[0].value, 0)
                    elif d.directive == 'auto_address':
                        current_address = None
                    else:
                        raise ValueError('Invalid directive "{0}"'.format(d.directive))
                elif kind == 'MicroInstruction':
                    instr = MicroInstruction(f, line, current_address)
                    instructions.append(instr)
                    # The field handlers invoked by _parse() write into this.
                    self._current_instr = instr

                    self._parse(d.fields)

                    if current_label:
                        label_map[current_label] = instr
                    # A label / explicit address applies to one instruction only.
                    current_address = None
                    current_label = None
                elif kind == 'PreprocessorDirective':
                    # Only markers that refer to the file itself re-base the
                    # line numbering.
                    if d.Filename == f:
                        line, _ = mm.parser.pos_to_linecol(d._tx_position)
                        line_offset = line - d.LineNumber
                else:
                    raise ValueError('Unexpected token type {0}'.format(kind))

        assign_addresses()
        resolve_labels()
        self.finalize_immediates()
        self.finalize_flags()

        return sorted(instructions, key=lambda i: i.address)

    def _write_bin(self, filename, instructions):
        """Write the encoded instructions to ``filename`` as binary text.

        The file starts with the field-name comment.  Each instruction then
        gets one line holding its encoded bits followed by a ``//`` comment
        with its hex address, origin and present field names.  An
        ``@ <hex address>`` line is emitted first whenever the address does
        not directly follow the previous one.
        """
        with open(filename, 'w') as f:
            f.write(self.field_comment())
            # Every word uses the same address width, derived from the
            # highest address (the last entry of the list).
            address_bits = num_bits(instructions[-1].address + 1) if instructions else 0
            expected_addr = 0
            for instr in instructions:
                if instr.address != expected_addr:
                    f.write('@ {0:x}\n'.format(instr.address))
                # Encode before listing the field names: encode() may add
                # entries to present_fields.
                encoded = instr.encode(address_bits)
                present = ' '.join(instr.present_fields.keys())
                f.write('%s // %x %s %s\n' % (encoded, instr.address, instr.origin, present))
                expected_addr = instr.address + 1

    def _write_mif(self, filename, instructions):
        """Render ``MIF_TEMPLATE`` for ``instructions`` into ``filename``.

        The template receives the number of words, the total word width and
        one ``{'address', 'value'}`` entry per instruction.
        """
        num_words = instructions[-1].address + 1
        address_bits = num_bits(num_words)

        instr_list = []
        for instr in instructions:
            instr_list.append({'address': '%x' % (instr.address,),
                               'value': instr.encode(address_bits)})

        data = {
            'num_instructions': num_words,
            'width': self.width(address_bits),
            'instructions': instr_list,
        }

        with open(MIF_TEMPLATE) as template_file:
            template = template_file.read()
        with open(filename, 'w') as outfile:
            outfile.write(pystache.render(template, data))

    def _write_microcode_verilog(self, filename, instructions):
        """Render ``TEMPLATE`` into ``filename``.

        The template receives the field layout, the exported fields, the
        number of words, the address width and the immediate / flags tables.
        """
        num_words = instructions[-1].address + 1
        address_bits = num_bits(num_words)

        data = {}
        data['fields'] = MicroAssembler.bitfields(address_bits)
        data['exported_fields'] = MicroAssembler.exported_fields()
        data['num_instructions'] = num_words
        data['addr_bits'] = address_bits
        data['immediates'] = MicroAssembler.get_immediate_dict()
        data['flags'] = MicroAssembler.get_flags_dict()

        with open(TEMPLATE) as template_file:
            template = template_file.read()
        with open(filename, 'w') as outfile:
            outfile.write(pystache.render(template, data))

    def _write_types(self, verilog_filename, cpp_filename):
        """Render the enum descriptions into the Verilog and C++ type files.

        The Verilog file is produced first, then the C++ one; both templates
        receive the same ``enums`` list from ``_gather_enums``.
        """
        enums = self._gather_enums()

        jobs = (
            (VERILOG_ENUM_TEMPLATE, verilog_filename),
            (CPP_ENUM_TEMPLATE, cpp_filename),
        )
        for template_path, output_path in jobs:
            with open(template_path) as f:
                template = f.read()
            with open(output_path, 'w') as f:
                f.write(pystache.render(template, {'enums': enums}))

    @staticmethod
    def _gather_enums():
        """Describe every exported enum type for the type templates.

        Returns a list with one dict per enum class holding its type name
        (``MC_<Class>_t``), its bit width, its highest bit index and an
        ``items`` list of ``{'name', 'value'}`` entries.  Every item value
        except the last carries a trailing comma.
        """
        enum_classes = (ADriver, BDriver, ALUOp, JumpType, RDSelSource,
                        MARWrSel, TEMPWrSel, UpdateFlags, RegWrSource, WidthType)

        enums = []
        for cls in enum_classes:
            last_idx = len(cls.__members__) - 1
            items = [
                {'name': '%s_%s' % (cls.__name__, member),
                 'value': '%d%s' % (cls[member].value, '' if idx == last_idx else ',')}
                for idx, member in enumerate(cls.__members__)
            ]
            bits = num_bits(len(cls))
            enums.append({
                'high_bit': bits - 1,
                'name': 'MC_%s_t' % (cls.__name__,),
                'items': items,
                'num_bits': bits,
            })
        return enums

    def _enumerated_field(self, field, enum_class=None):
        """Store the value of the ``enum_class`` member named by the first argument."""
        member = enum_class[field.arguments[0]]
        present = self._current_instr.present_fields
        present[field.name] = member.value

    def _boolean(self, field):
        """Set the single-bit field named ``field.name`` on the current instruction."""
        present = self._current_instr.present_fields
        present[field.name] = 1

    def _update_flags(self, field):
        """Record the flag mask named by ``field`` on the current instruction.

        The mask has bit ``UpdateFlags[arg].value`` set for every argument.
        Masks not seen before are appended to the update-flags pool.
        """
        mask = 0
        for bit in (UpdateFlags[arg].value for arg in field.arguments):
            mask |= 1 << bit

        pool = self._update_flags_pool
        if mask not in pool:
            pool.append(mask)
        self._current_instr.present_fields['update_flags'] = mask

    def _jump(self, field, jump_type=None):
        """Mark the current instruction as a jump of kind ``jump_type``.

        Every jump type except ``JumpType.OPCODE`` takes its target label
        from the first argument of ``field``.
        """
        instr = self._current_instr
        takes_label = jump_type != JumpType.OPCODE
        if takes_label:
            instr.next_label = field.arguments[0]
        instr._has_jump = True
        instr.present_fields['jump_type'] = jump_type.value

    def _reserved(self, field):
        """Reject a keyword that may not be used directly in microcode source."""
        message = 'Reserved keyword {0}'.format(field.name)
        raise ValueError(message)

    def _parse(self, fields):
        """Run the handler registered in ``microcode_fields`` for each field."""
        for field in fields:
            _, handler, _ = self.microcode_fields[field.name]
            handler(self, field)

    def _immediate(self, field):
        """Record the immediate given by ``field`` on the current instruction.

        The single argument is parsed as a base-16 integer.  Values not seen
        before are appended to the immediate pool.

        Raises:
            ValueError: if ``field`` does not have exactly one argument.
        """
        if len(field.arguments) != 1:
            raise ValueError('Only one immediate value supported')

        value = int(field.arguments[0].value, 16)
        pool = self._immediate_pool
        if value not in pool:
            pool.append(value)
        self._current_instr.present_fields['immediate'] = value

    def finalize_immediates(self):
        """Build the immediate lookup table and size the ``immediate`` field.

        ``MicroAssembler.immediate_dict`` maps each pooled immediate value to
        its 1-based position in the pool, with value 0 pre-seeded at index 0.
        The ``immediate`` entry of ``microcode_fields`` is then replaced so
        that its width is large enough for those indices.
        """
        MicroAssembler.immediate_dict = {0: 0}
        for idx, val in enumerate(self._immediate_pool):
            MicroAssembler.immediate_dict[val] = idx + 1
        width = num_bits(len(MicroAssembler.immediate_dict.keys()) + 1)
        # Store the plain function, not the bound method: _parse() invokes
        # every handler as handler(self, field), so a bound method would
        # receive ``self`` twice if anything is parsed after finalization.
        self.microcode_fields['immediate'] = (width, MicroAssembler._immediate, False)

    def finalize_flags(self):
        """Build the update-flags lookup table and size the ``update_flags`` field.

        ``MicroAssembler.update_flags_dict`` maps each pooled flag mask to its
        1-based position in the pool, with mask 0 pre-seeded at index 0.  The
        ``update_flags`` entry of ``microcode_fields`` is then replaced so
        that its width is large enough for those indices.
        """
        MicroAssembler.update_flags_dict = {0: 0}
        for idx, val in enumerate(self._update_flags_pool):
            MicroAssembler.update_flags_dict[val] = idx + 1
        width = num_bits(len(MicroAssembler.update_flags_dict.keys()) + 1)
        # Store the plain function, not the bound method: _parse() invokes
        # every handler as handler(self, field), so a bound method would
        # receive ``self`` twice if anything is parsed after finalization.
        self.microcode_fields['update_flags'] = (width, MicroAssembler._update_flags, False)

    """List of microcode fields:
      (number of bits, handler, exported to pipeline boolean.
    """
    # Maps each microcode keyword to a 3-tuple:
    #   [0] field width in bits; entries of width 0 produce no bits in
    #       MicroInstruction.encode() and are left out of bitfields() and
    #       exported_fields(),
    #   [1] handler, invoked by _parse() as handler(self, field),
    #   [2] whether exported_field_keys() lists the field.
    # The 'immediate' and 'update_flags' entries are replaced by
    # finalize_immediates() / finalize_flags() once the value pools are known.
    microcode_fields = {
        'ra_sel': (num_bits(len(GPR)), partial(_enumerated_field, enum_class=GPR), True),
        'rb_cl': (1, _boolean, True),
        'rd_sel': (num_bits(len(GPR)), partial(_enumerated_field, enum_class=GPR), True),
        'next_instruction': (1, _boolean, True),
        'mar_wr_sel': (num_bits(len(MARWrSel)), partial(_enumerated_field, enum_class=MARWrSel), True),
        'mar_write': (1, _boolean, True),
        'mdr_write': (1, _boolean, True),
        'mem_read': (1, _boolean, True),
        'mem_write': (1, _boolean, True),
        'segment': (num_bits(len(SR)), partial(_enumerated_field, enum_class=SR), True),
        'segment_wr_en': (1, _boolean, True),
        'segment_force': (1, _boolean, True),
        'a_sel': (num_bits(len(ADriver)), partial(_enumerated_field, enum_class=ADriver), True),
        'b_sel': (num_bits(len(BDriver)), partial(_enumerated_field, enum_class=BDriver), True),
        'alu_op': (num_bits(len(ALUOp)), partial(_enumerated_field, enum_class=ALUOp), True),
        'update_flags': (len(UpdateFlags), _update_flags, False),
        'ra_modrm_rm_reg': (1, _boolean, True),
        'rd_sel_source': (num_bits(len(RDSelSource)), partial(_enumerated_field, enum_class=RDSelSource), True),
        'reg_wr_source': (num_bits(len(RegWrSource)), partial(_enumerated_field, enum_class=RegWrSource), True),
        'tmp_wr_en': (1, _boolean, True),
        'tmp_wr_sel': (num_bits(len(TEMPWrSel)), partial(_enumerated_field, enum_class=TEMPWrSel), True),
        'width': (num_bits(len(WidthType)), partial(_enumerated_field, enum_class=WidthType), False),
        'load_ip': (1, _boolean, True),
        'io': (1, _boolean, True),
        'ext_int_yield': (1, _boolean, True),
        'ext_int_inhibit': (1, _boolean, False),
        'immediate': (0, _immediate, False),
        # Internal keywords, generated from others
        'jump_type': (num_bits(len(JumpType)), _reserved, False),
        # Populated once we know how many address bits required
        'jump_target': (0, _reserved, False),
        # Keywords only, internally mapped to other microcode fields
        'jmp_opcode': (0, partial(_jump, jump_type=JumpType.OPCODE), False),
        'jmp_rm_reg_mem': (0, partial(_jump, jump_type=JumpType.RM_REG_MEM), False),
        'jmp_dispatch_reg': (0, partial(_jump, jump_type=JumpType.DISPATCH_REG), False),
        'jmp_if_not_rep': (0, partial(_jump, jump_type=JumpType.HAS_NO_REP_PREFIX), False),
        'jmp_if_zero': (0, partial(_jump, jump_type=JumpType.ZERO), False),
        'jmp_if_rep_not_complete': (0, partial(_jump, jump_type=JumpType.REP_NOT_COMPLETE), False),
        'jmp_if_taken': (0, partial(_jump, jump_type=JumpType.JUMP_TAKEN), False),
        'jmp_rb_zero': (0, partial(_jump, jump_type=JumpType.RB_ZERO), False),
        'jmp_loop_done': (0, partial(_jump, jump_type=JumpType.LOOP_DONE), False),
        'jmp': (0, partial(_jump, jump_type=JumpType.UNCONDITIONAL), False),
    }

    @staticmethod
    def sorted_field_keys():
        """Return every microcode field name in sorted order."""
        field_names = list(MicroAssembler.microcode_fields)
        field_names.sort()
        return field_names

    @staticmethod
    def exported_field_keys():
        """Return the sorted names of the fields whose exported flag is set."""
        table = MicroAssembler.microcode_fields
        return sorted(name for name in table.keys() if table[name][2])

    @staticmethod
    def field_comment():
        """Return a ``//`` comment line naming ``next`` and every non-empty field.

        Fields are listed in sorted order; fields of width 0 are left out.
        """
        table = MicroAssembler.microcode_fields
        names = ['next']
        for name in MicroAssembler.sorted_field_keys():
            if table[name][0] > 0:
                names.append(name)
        return '//' + ' '.join(names) + '\n'

    @staticmethod
    def width(address_bits):
        """Return the total word width: ``address_bits`` plus every field width."""
        table = MicroAssembler.microcode_fields
        field_bits = sum(table[name][0] for name in MicroAssembler.sorted_field_keys())
        return address_bits + field_bits

    @staticmethod
    def bitfields(address_bits):
        """Describe the word layout for the Verilog template.

        Returns a list of ``{'type', 'name'}`` dicts: first the ``next``
        field, ``address_bits`` wide, then every field of non-zero width in
        sorted name order.  Each ``type`` is a zero-argument callable that
        yields the range prefix from ``field_type``.
        """
        table = MicroAssembler.microcode_fields
        layout = [{
            'type': lambda: field_type(address_bits),
            'name': 'next'
        }]
        sized = ((name, table[name][0]) for name in MicroAssembler.sorted_field_keys())
        layout.extend({'type': partial(field_type, size), 'name': name}
                      for name, size in sized if size != 0)
        return layout

    @staticmethod
    def get_immediate_dict():
        """Return the immediate table as ``{'idx', 'val'}`` dicts, ``val`` in hex."""
        return [{'idx': index, 'val': '{0:x}'.format(value)}
                for value, index in MicroAssembler.immediate_dict.items()]

    @staticmethod
    def get_flags_dict():
        """Return the update-flags table as ``{'idx', 'val'}`` dicts, ``val`` in hex."""
        return [{'idx': index, 'val': '{0:x}'.format(mask)}
                for mask, index in MicroAssembler.update_flags_dict.items()]

    @staticmethod
    def exported_fields():
        """Describe the exported fields for the Verilog template.

        Returns ``{'type', 'name'}`` dicts for every exported field of
        non-zero width, in sorted name order.  Each ``type`` is a
        zero-argument callable that yields the range prefix from
        ``field_type``.
        """
        table = MicroAssembler.microcode_fields
        sized = [(name, table[name][0]) for name in MicroAssembler.exported_field_keys()]
        return [{'type': partial(field_type, size), 'name': name}
                for name, size in sized if size != 0]

    def assemble(self):
        """Assemble the input files and write every output file.

        The outputs are written in this order: binary text, MIF, microcode
        Verilog, then the Verilog and C++ type definitions.
        """
        program = self._build_microcode(self._infiles)

        self._write_bin(self._bin_output, program)
        self._write_mif(self._mif_output, program)
        self._write_microcode_verilog(self._verilog_output, program)

        # The type files depend only on the enum definitions, not on the program.
        self._write_types(self._verilog_types_output,
                          self._cpp_types_output)

class MicroInstruction(object):
    """One microcode word plus the bookkeeping needed to place and encode it."""

    def __init__(self, filename, line, address=None):
        """Create an instruction originating from ``filename``:``line``.

        Args:
            filename: source file name, used only for the ``origin`` string.
            line: source line number, used only for the ``origin`` string.
            address: explicit address, or ``None`` to have one assigned later.
        """
        # Address of the instruction executed next; filled in by the assembler.
        self.next = None
        # Label to jump to, if a jump field named one.
        self.next_label = None
        self.address = address
        self._has_jump = False
        # Field name -> raw value, as recorded by the field handlers.
        self.present_fields = {}
        self.origin = '%s:%d' % (filename, line)

    def encode(self, address_bits):
        """Return the instruction as a string of ``0``/``1`` characters.

        The word consists of ``next`` in ``address_bits`` bits followed by
        every field of non-zero width in sorted name order; fields that are
        not present encode as zero.  ``immediate`` and ``update_flags`` are
        encoded as their index in the assembler's lookup tables.

        The translation works on a copy of ``present_fields``, so calling
        this method repeatedly (with any ``address_bits``) always yields a
        consistent result.  The only change made to ``present_fields`` is
        that ``jump_target`` is recorded for jump instructions.

        Args:
            address_bits: width of the ``next`` field.

        Raises:
            ValueError: if a field value does not fit in its declared width.
        """
        resolved = dict(self.present_fields)
        # Bookkeeping entry left by an earlier call; it is re-added below and
        # must not take part in the width check.
        resolved.pop('jump_target', None)
        if 'immediate' in resolved:
            resolved['immediate'] = MicroAssembler.immediate_dict[resolved['immediate']]
        if 'update_flags' in resolved:
            resolved['update_flags'] = MicroAssembler.update_flags_dict[resolved['update_flags']]

        for name, value in resolved.items():
            width = MicroAssembler.microcode_fields[name][0]
            # bit_length() is the exact number of bits the value occupies; a
            # wider value would silently lengthen the word and shift every
            # field after it.
            if value < 0 or value.bit_length() > width:
                raise ValueError('Field "{0}" value {1} does not fit in {2} bits ({3})'.format(
                    name, value, width, self.origin))

        if self._has_jump:
            self.present_fields['jump_target'] = self.next
            resolved['jump_target'] = self.next

        field_values = ['{0:0{width}b}'.format(self.next, width=address_bits)]
        for name in MicroAssembler.sorted_field_keys():
            width = MicroAssembler.microcode_fields[name][0]
            if width == 0:
                continue
            field_values.append('{0:0{width}b}'.format(resolved.get(name, 0), width=width))
        return ''.join(field_values)

# Command line: optional -I include directories, five output paths, then one
# or more microcode source files.
parser = argparse.ArgumentParser()
parser.add_argument('--include', '-I', action='append', default=[])
# Outputs, in the order they must be given.
parser.add_argument('bin_output')
parser.add_argument('mif_output')
parser.add_argument('verilog_output')
parser.add_argument('verilog_types_output')
parser.add_argument('cpp_types_output')
# Inputs.
parser.add_argument('input', nargs='+')
args = parser.parse_args()

assembler = MicroAssembler(
    infiles=args.input,
    bin_output=args.bin_output,
    mif_output=args.mif_output,
    verilog_output=args.verilog_output,
    verilog_types_output=args.verilog_types_output,
    cpp_types_output=args.cpp_types_output,
    include=args.include,
)
assembler.assemble()
